import torch
import numpy as np
import random
from model2 import MyModel
from torchvision.transforms import transforms
import cv2
from copy import deepcopy

def prediction(model: MyModel, img_path, save_path, device):
    """Inpaint a randomly masked copy of an image and save a side-by-side comparison.

    The image at ``img_path`` is resized to 256x256. A random rectangle
    (64-128 px per side) is blacked out, and the masked image plus a binary
    hole mask are passed to ``model``. The masked input (left) and the model
    output (right) are concatenated horizontally and written to ``save_path``.

    Note: this switches ``model`` into eval mode and leaves it there.

    Args:
        model: Inpainting network, called as ``model(image, mask)``.
        img_path: Path of the input image, read with ``cv2.imread``.
        save_path: Destination of the comparison image; its extension selects
            the format used by ``cv2.imwrite``.
        device: Torch device the input tensors are moved to.

    Raises:
        FileNotFoundError: If ``img_path`` cannot be read as an image.
        OSError: If the comparison image cannot be written to ``save_path``.
    """
    model.eval()
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"could not read image: {img_path}")
    # NOTE(review): cv2.imread yields BGR channel order and no RGB conversion
    # is applied before or after the model -- confirm this matches training.
    img = cv2.resize(img, (256, 256))  # dsize is (width, height)
    h, w = img.shape[:2]

    # Random hole: the top-left corner is chosen so a hole of up to 128 px
    # still fits; the far edges are clamped to the image bounds.
    # (randint call order x, y, width, height is kept for seed reproducibility.)
    x0 = random.randint(0, max(0, w - 128 - 1))
    y0 = random.randint(0, max(0, h - 128 - 1))
    x1 = min(w, x0 + random.randint(64, 128))
    y1 = min(h, y0 + random.randint(64, 128))

    img[y0:y1, x0:x1, :] = 0
    # Mask is 255 inside the hole and 0 elsewhere; ToTensor rescales it to [0, 1].
    mask = np.zeros((h, w, 1), dtype=np.uint8)
    mask[y0:y1, x0:x1, :] = 255

    # ToTensor: HxWxC uint8 -> CxHxW float in [0, 1]; unsqueeze adds a batch dim.
    inp = transforms.ToTensor()(img).unsqueeze(0).to(device)
    mask_t = transforms.ToTensor()(mask).unsqueeze(0).to(device)

    with torch.no_grad():
        out = model(inp, mask_t)

    # Drop only the batch dim. A bare squeeze() would also remove a size-1
    # channel dim and break the transpose below for single-channel outputs.
    # Assumes the output is (1, C, H, W) with values nominally in [0, 1] -- confirm.
    out_np = out.squeeze(0).cpu().numpy()
    # Clip before the cast: out-of-range floats would otherwise wrap in uint8.
    out_img = (np.clip(out_np, 0.0, 1.0) * 255).astype(np.uint8)
    out_img = np.transpose(out_img, (1, 2, 0))  # CxHxW -> HxWxC

    # Concatenate input and output: pad both onto equal-size black canvases,
    # then place them side by side (input left, output right).
    h2, w2 = out_img.shape[:2]
    canvas_h, canvas_w = max(h, h2), max(w, w2)
    left = np.zeros((canvas_h, canvas_w, 3), dtype=np.uint8)
    right = np.zeros((canvas_h, canvas_w, 3), dtype=np.uint8)
    left[:h, :w, :] = img
    right[:h2, :w2, :] = out_img  # an (H, W, 1) output broadcasts to 3 channels
    cat_img = np.concatenate([left, right], axis=1)

    # cv2.imwrite reports failure via its boolean return value.
    if not cv2.imwrite(save_path, cat_img):
        raise OSError(f"could not write image: {save_path}")

